import pandas as pd
import numpy as np
import os
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import cross_validate
from sklearn.metrics import classification_report, confusion_matrix
from sklearn import metrics
import pickle
from sklearn.utils import shuffle
from sklearn.model_selection import KFold
from scipy.stats import sem
from sklearn.decomposition import PCA
from sklearn.metrics import make_scorer, accuracy_score, precision_score, recall_score, f1_score
import datetime
import sys
import argparse

"""
This script performs MaMaDroid's classification using Random Forest classifier.
Inputs are described in parseargs function.
Scores (Precision, Recall, F1-score) are calculated using 10-folds cross-validation with and without PCA for family and package modes. 
The Outputs are a text file on which scores are written.
"""


def parseargs():
    """Parse the command-line arguments of the script.

    Returns:
        argparse.Namespace with the attributes ``fami`` (path to the
        family-mode CSV files), ``pack`` (path to the package-mode CSV
        files), ``filescores`` (name of the results file) and ``seed``
        (integer seed; drawn at random when ``-sd/--seed`` is not given).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-pf",
        "--fami",
        help=
        "The path to the CSV files of family mode, for drebin, 2013, 2014, 2015, 2016, oldbenign, and newbenign datasets. These files are generated by MaMaStat.py script",
        type=str,
        required=True
    )
    parser.add_argument(
        "-pp",
        "--pack",
        help=
        "The path to the CSV files of package mode, for drebin, 2013, 2014, 2015, 2016, oldbenign, and newbenign datasets. These files are generated by MaMaStat.py script",
        type=str,
        required=True
    )
    parser.add_argument(
        "-fs",
        "--filescores",
        help="The name of the file on which the results will be written",
        type=str,
        required=True
    )
    parser.add_argument(
        "-sd",
        "--seed",
        help="The seed to fix for the experiments",
        type=int,
        required=False,
        # The dtype is forced to 64 bits: where NumPy's default integer is
        # 32-bit, an upper bound of 2**32 - 1 is out of range and randint
        # raises ValueError. int() yields a plain Python integer, like a
        # seed given on the command line.
        default=int(np.random.randint(0, 2**32 - 1, dtype=np.int64))
    )
    return parser.parse_args()


# Dataset pairs evaluated by the script: each entry is
# [malware dataset name, goodware dataset name].
data = [
    ["2016", "newbenign"],
    ["2015", "newbenign"],
    ["2014", "oldbenign"],
    ["2014", "newbenign"],
    ["2013", "oldbenign"],
    ["drebin", "oldbenign"],
]

# Names of the two modes of operation.
mode_family, mode_package = "family", "package"

# Scorers handed to cross_validate; each key K is read back from the
# cross-validation result as "test_K".
scoring = {
    metric_name: make_scorer(metric_fn)
    for metric_name, metric_fn in (
        ('accuracy', accuracy_score),
        ('precision', precision_score),
        ('recall', recall_score),
        ('f1_score', f1_score),
    )
}

# Feature columns that filter_family() removes from the cleaned family-mode
# CSV files, to match what is stated in the paper.
# Each name has the form ' <source>To<destination>' and carries a leading
# space (presumably matching the CSV header after cleaning -- confirm).
# Every entry has org.w3c.dom, org.json or dalvik as its source or as its
# destination.
# NOTE(review): ' org.xml.Todalvik.' is absent although the transition into
# dalvik. is listed for every other source family in this list -- confirm
# whether that omission is intentional.
DROP_COLUMNS = [
    ' org.w3c.dom.Tocom.google.', ' org.w3c.dom.Toorg.xml.',
    ' org.w3c.dom.Toorg.apache.', ' org.w3c.dom.Tojavax.',
    ' org.w3c.dom.Tojava.', ' org.w3c.dom.Toandroid.',
    ' org.w3c.dom.Toorg.w3c.dom.', ' org.w3c.dom.Toorg.json.',
    ' org.w3c.dom.Todalvik.', ' org.w3c.dom.Toselfdefined',
    ' org.w3c.dom.Toobfuscated', ' org.json.Tocom.google.',
    ' org.json.Toorg.xml.', ' org.json.Toorg.apache.', ' org.json.Tojavax.',
    ' org.json.Tojava.', ' org.json.Toandroid.', ' org.json.Toorg.w3c.dom.',
    ' org.json.Toorg.json.', ' org.json.Todalvik.', ' org.json.Toselfdefined',
    ' org.json.Toobfuscated', ' com.google.Toorg.w3c.dom.',
    ' com.google.Toorg.json.', ' org.xml.Toorg.w3c.dom.',
    ' org.xml.Toorg.json.', ' org.apache.Toorg.w3c.dom.',
    ' org.apache.Toorg.json.', ' javax.Toorg.w3c.dom.', ' javax.Toorg.json.',
    ' java.Toorg.w3c.dom.', ' java.Toorg.json.', ' android.Toorg.w3c.dom.',
    ' android.Toorg.json.', ' dalvik.Toorg.w3c.dom.', ' dalvik.Toorg.json.',
    ' selfdefinedToorg.w3c.dom.', ' selfdefinedToorg.json.',
    ' obfuscatedToorg.w3c.dom.', ' obfuscatedToorg.json.',
    ' dalvik.Tocom.google.', ' dalvik.Toorg.xml.', ' dalvik.Toorg.apache.',
    ' dalvik.Tojavax.', ' dalvik.Tojava.', ' dalvik.Toandroid.',
    ' dalvik.Todalvik.', ' dalvik.Toselfdefined', ' dalvik.Toobfuscated',
    ' com.google.Todalvik.', ' javax.Todalvik.', ' java.Todalvik.',
    ' android.Todalvik.', ' selfdefinedTodalvik.', ' obfuscatedTodalvik.',
    ' org.apache.Todalvik.'
]


def clean_file(path_file, path_save, mode):
    """Write a cleaned copy of a CSV features file.

    Every "[", "]" and "'" character is removed from the content of
    ``path_file`` and the result is written to ``path_save``. In family
    mode the saved file is then filtered with filter_family(), which
    removes the DROP_COLUMNS columns.

    Args:
        path_file: path of the CSV features file to read.
        path_save: path where the cleaned file is written.
        mode: mode of operation; only the value "family" triggers the
            column filtering step.

    Returns:
        None.
    """
    # A single translation table deletes all three characters in one pass,
    # instead of three replace() calls per line and a growing string.
    strip_table = str.maketrans("", "", "[]'")
    # The input is fully read before the output is opened, so the function
    # still works when both paths designate the same file.
    with open(path_file, 'r') as my_file:
        cleaned = my_file.read().translate(strip_table)
    with open(path_save, 'w') as save_file:
        save_file.write(cleaned)
    if mode == "family":
        # NOTE(review): if the filtering step raises, the unfiltered file
        # stays at path_save; CalculateScores() only checks that the file
        # exists, so a later run would reuse it -- confirm this is acceptable.
        filter_family(path_save)


def filter_family(path_file):
    """Remove the DROP_COLUMNS columns from a CSV file, in place.

    The file at ``path_file`` is loaded with its first column as index,
    the columns listed in DROP_COLUMNS are dropped, and the result is
    written back to the same path.
    """
    frame = pd.read_csv(path_file, index_col=0)
    frame.drop(columns=DROP_COLUMNS).to_csv(path_file)


def csv_file_to_nparray(fname):
    """Load the numeric content of a CSV file into a float32 array.

    The first line of the file (header) is skipped, and the first column
    of every line is discarded before parsing.

    Args:
        fname: path of the CSV file to read.

    Returns:
        A 2-D ``numpy.float32`` array with one row per data line, including
        when the file holds a single data line.

    Raises:
        ValueError: if a line contains no comma, or if a remaining field
            cannot be parsed as a number.
    """
    def remove_first_col(line):
        # Keep everything after the first comma.
        return line[line.index(',') + 1:]

    # The file is owned by this "with" block, so it is closed even when
    # loadtxt fails in the middle of the file.
    with open(fname, 'r') as f_in:
        return np.loadtxt(
            (remove_first_col(line) for line in f_in),
            dtype=np.float32,
            delimiter=',',
            converters=None,
            skiprows=1,
            # Without ndmin=2 a file with a single data line yields a 1-D
            # array, whose shape[0] is the number of columns, not of rows.
            ndmin=2
        )


def constructDataframes(mal_data, good_data):
    """Build a shuffled, labelled dataset from two CSV files.

    Both files are loaded with csv_file_to_nparray(). Malware rows get the
    label 1 and goodware rows the label -1; the rows of both sets are then
    stacked and shuffled together with np.random.shuffle.

    Args:
        mal_data: path of the malware CSV file.
        good_data: path of the goodware CSV file.

    Returns:
        A tuple (X, Y): X holds every column but the last one (features),
        Y holds the last column (labels).
    """
    labelled_sets = []
    for csv_path, label in ((mal_data, 1), (good_data, -1)):
        features = csv_file_to_nparray(csv_path)
        label_column = np.full(
            (features.shape[0], 1), fill_value=label, dtype=np.float32
        )
        # Append the label as an extra, last column
        labelled_sets.append(np.hstack((features, label_column)))
    # Features and labels are shuffled together so that they stay aligned
    features_and_labels = np.vstack(labelled_sets)
    np.random.shuffle(features_and_labels)
    return features_and_labels[:, :-1], features_and_labels[:, -1]


def Classification(x, y, data, mode):
    """Cross-validate a Random Forest on (x, y) and format the mean scores.

    Args:
        x: feature matrix handed to cross_validate.
        y: labels handed to cross_validate.
        data: name of the dataset pair, inserted as-is in the result string.
        mode: "family" or "package"; selects the hyper-parameters of the
            Random Forest.

    Returns:
        A string with the mean and the standard error of the mean of the
        precision, recall and f1 score over the 10 folds.

    Raises:
        ValueError: if ``mode`` is neither "family" nor "package".
    """
    # hyper-parameters of original authors, for each mode
    forest_params = {
        "family": {"max_depth": 8, "n_estimators": 51},
        "package": {"max_depth": 64, "n_estimators": 101},
    }
    # Fail early with a clear message: an unknown mode previously left the
    # model undefined and failed later with an UnboundLocalError.
    if mode not in forest_params:
        raise ValueError(
            "Unknown mode {!r}: expected 'family' or 'package'".format(mode)
        )
    model_rf = RandomForestClassifier(**forest_params[mode])
    # NOTE(review): no random_state is given to the forest or to KFold;
    # presumably reproducibility relies on the global NumPy seed -- confirm.
    cv = KFold(10, shuffle=True)  # prepare 10 folds
    # perform the cross validation
    scores = cross_validate(model_rf, x, y, cv=cv, scoring=scoring)
    # retrieve the scores, and calculate the average
    results = "Mean score " + data + ": Precision: {0:.3f} (+/-{1:.3f}) | Recall: {2:.3f} (+/-{3:.3f}) | f1 score: {4:.3f} (+/-{5:.3f}) ".format(
        np.mean(scores["test_precision"]), sem(scores["test_precision"]),
        np.mean(scores["test_recall"]), sem(scores["test_recall"]),
        np.mean(scores["test_f1_score"]), sem(scores["test_f1_score"])
    )
    return results


def Classification_PCA(x, y, data, mode):
    """Cross-validate a Random Forest on the PCA projection of x.

    Args:
        x: feature matrix; it is projected on 10 principal components
            before the cross-validation.
        y: labels handed to cross_validate.
        data: name of the dataset pair, inserted as-is in the result string.
        mode: "family" or "package"; selects the hyper-parameters of the
            Random Forest.

    Returns:
        A string with the mean and the standard error of the mean of the
        precision, recall and f1 score over the 10 folds.

    Raises:
        ValueError: if ``mode`` is neither "family" nor "package".
    """
    # hyper-parameters of original authors, for each mode
    forest_params = {
        "family": {"max_depth": 8, "n_estimators": 51},
        "package": {"max_depth": 64, "n_estimators": 101},
    }
    # Validate before the PCA fit: an unknown mode previously left the model
    # undefined and only failed, after the projection, with an
    # UnboundLocalError.
    if mode not in forest_params:
        raise ValueError(
            "Unknown mode {!r}: expected 'family' or 'package'".format(mode)
        )
    # NOTE(review): the PCA is fitted on the whole dataset before the folds
    # are built, so the test folds take part in the projection. Kept as-is;
    # confirm this matches the protocol being reproduced.
    pca = PCA(n_components=10)  # 10 selected component from original paper
    x = pca.fit_transform(x)
    model_rf = RandomForestClassifier(**forest_params[mode])
    cv = KFold(10, shuffle=True)  # prepare 10 folds
    scores = cross_validate(model_rf, x, y, cv=cv, scoring=scoring)
    resultsPCA = "Mean score pca " + data + ": Precision: {0:.3f} (+/-{1:.3f}) | Recall: {2:.3f} (+/-{3:.3f}) | f1 score: {4:.3f} (+/-{5:.3f}) ".format(
        np.mean(scores["test_precision"]), sem(scores["test_precision"]),
        np.mean(scores["test_recall"]), sem(scores["test_recall"]),
        np.mean(scores["test_f1_score"]), sem(scores["test_f1_score"])
    )
    return resultsPCA


def CalculateScores(path, data_mal, data_good, mode):
    """Calculate the scores, with and without PCA, for one pair of datasets.

    The cleaned CSV files "<name>_clean.csv" are created from "<name>.csv"
    with clean_file() when they do not exist yet, and reused otherwise.

    Args:
        path: directory holding the CSV files; it depends on the mode of
            operation.
        data_mal: name of a malware dataset (e.g., 2013).
        data_good: name of a goodware dataset.
        mode: mode of operation, forwarded to clean_file(), Classification()
            and Classification_PCA().

    Returns:
        A tuple (scores, scores_pca) with the result strings of
        Classification() and Classification_PCA().
    """
    # construct the base path (without extension) of both datasets
    mal, good = os.path.join(path, data_mal), os.path.join(path, data_good)
    # path of each raw CSV file, and path where its cleaned version is saved
    path_file_mal, path_mal_save = mal + ".csv", mal + "_clean.csv"
    path_file_good, path_good_save = good + ".csv", good + "_clean.csv"
    # only clean the files that have not been cleaned by a previous run
    if not os.path.exists(path_mal_save):
        clean_file(path_file_mal, path_mal_save, mode)
    if not os.path.exists(path_good_save):
        clean_file(path_file_good, path_good_save, mode)
    x, y = constructDataframes(path_mal_save, path_good_save)
    pair_name = data_mal + data_good
    scores = Classification(x, y, pair_name, mode)
    scores_pca = Classification_PCA(x, y, pair_name, mode)
    return scores, scores_pca


if __name__ == '__main__':
    Args = parseargs()  # retrieve the parameters
    path_family = Args.fami
    path_package = Args.pack
    file_scores = Args.filescores
    SEED = Args.seed
    np.random.seed(SEED)
    # The seed is drawn at random when not given on the command line:
    # report it, so that the run can be reproduced with -sd.
    print("Seed used for this run: {}".format(SEED))
    # each mode of operation is evaluated on its own directory of CSV files
    mode_paths = ((path_family, mode_family), (path_package, mode_package))
    # Scores are appended to this file. For each pair of datasets, the file
    # will have Precision, Recall, and F1 score for family and package mode,
    # with and without PCA. The "with" block closes the file even if a
    # computation fails.
    with open(file_scores, "a") as outp:
        # calculate the scores for each pair of datasets in data list
        for data_mal, data_good in data:
            for path_mode, mode in mode_paths:
                scores, scores_pca = CalculateScores(
                    path_mode, data_mal, data_good, mode
                )
                outp.write(data_mal + "__" + data_good + "__" + mode + "\n")
                outp.write(scores + "\n")
                outp.write(scores_pca + "\n")
                # flush after each result so that a long run can be followed
                outp.flush()
